"""
Phase 6: Robustness Tests
===========================
Subsample tests, alternative specifications, and sensitivity analysis.

Input:  fiscal_dominance/data/processed/fiscal_panel.csv
Output: fiscal_dominance/output/tables/phase6_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Project layout.
# NOTE(review): absolute, machine-specific path (looks like a WSL mount of C:);
# consider making it configurable if the script is run elsewhere.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
FD_DIR = PROJECT_DIR / "fiscal_dominance"
PROCESSED_DIR = FD_DIR / "data" / "processed"  # holds fiscal_panel.csv (input)
TABLE_DIR = FD_DIR / "output" / "tables"  # phase6_*.csv tables are written here

# Make the shared PanelGLS estimator from the sibling "multilateral" project
# importable; this import only works after the path insertion above.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate one PanelGLS specification, print a report, and tabulate it.

    ``y``, ``X``, ``entity_ids`` and ``time_ids`` are passed unchanged to
    ``PanelGLS.fit``. ``feature_names`` labels the coefficients, and ``label``
    names the specification in the printed banner and in the returned table.

    Returns:
        (model, result_df): the fitted PanelGLS instance and the table from
        its ``to_dataframe``, with the columns model, n_obs, n_countries,
        r_squared and rho appended to every row.
    """
    fitted = PanelGLS()
    fitted.fit(y, X, entity_ids, time_ids)

    rule = '=' * 70
    print(f"\n{rule}")
    print(f"  {label}")
    print(f"  N={fitted.n_obs:,}, {fitted.n_countries} countries, "
          f"R²={fitted.r_squared:.4f}, rho={fitted.rho:.3f}")
    print(rule)
    fitted.summary(feature_names=feature_names)

    table = fitted.to_dataframe(feature_names=feature_names)
    # Tag each coefficient row with the spec label and fit statistics so that
    # tables from different specifications can be concatenated and compared.
    metadata = {
        'model': label,
        'n_obs': fitted.n_obs,
        'n_countries': fitted.n_countries,
        'r_squared': fitted.r_squared,
        'rho': fitted.rho,
    }
    for column, value in metadata.items():
        table[column] = value
    return fitted, table


# OECD members (ISO3): the 38 member countries. The list is static, so a
# country is treated as OECD in every sample year regardless of when it joined.
OECD = [
    'AUS', 'AUT', 'BEL', 'CAN', 'CHE', 'CHL', 'COL', 'CRI', 'CZE', 'DEU',
    'DNK', 'ESP', 'EST', 'FIN', 'FRA', 'GBR', 'GRC', 'HUN', 'IRL', 'ISL',
    'ISR', 'ITA', 'JPN', 'KOR', 'LTU', 'LUX', 'LVA', 'MEX', 'NLD', 'NOR',
    'NZL', 'POL', 'PRT', 'SVK', 'SVN', 'SWE', 'TUR', 'USA',
]

# Europe: the 27 EU member states plus NOR, CHE and GBR (ISL is not included).
EUROPE = [
    'AUT', 'BEL', 'BGR', 'HRV', 'CYP', 'CZE', 'DNK', 'EST', 'FIN', 'FRA',
    'DEU', 'GRC', 'HUN', 'IRL', 'ITA', 'LVA', 'LTU', 'LUX', 'MLT', 'NLD',
    'NOR', 'POL', 'PRT', 'ROU', 'SVK', 'SVN', 'ESP', 'SWE', 'CHE', 'GBR',
]

# East/Southeast Asia (TWN and HKG are treated as separate economies)
ASIA = ['JPN', 'KOR', 'CHN', 'SGP', 'THA', 'MYS', 'IDN', 'PHL', 'VNM', 'TWN', 'HKG']

# CCA (Caucasus and Central Asia) — often outliers; dropped as a group in Test 2.
# NOTE(review): MNG (Mongolia) is not part of the IMF's CCA regional grouping;
# confirm that including it here is intentional.
CCA = ['ARM', 'AZE', 'GEO', 'KAZ', 'KGZ', 'MNG', 'TJK', 'TKM', 'UZB']


def _run_panel_spec(df, dep_var, vars_list, label, results_list, min_obs):
    """Fit one PanelGLS specification on *df* and append its result table.

    Shared implementation of run_bohn_robustness and run_rg_robustness.
    Prints a message and returns without fitting when any required column is
    absent from *df*, or when fewer than *min_obs* complete-case rows remain.
    """
    regressors = list(vars_list)
    required = [dep_var] + regressors
    missing = [c for c in required if c not in df.columns]
    if missing:
        # Skip rather than crash: optional series vary across panel builds.
        print(f"  Skipping {label}: missing columns {missing}")
        return
    est = df.dropna(subset=required)
    if len(est) < min_obs:
        print(f"  Skipping {label}: insufficient obs ({len(est)})")
        return
    _, result_df = fit_and_report(
        est[dep_var].values, est[regressors].values,
        est['iso3'].values, est['year'].values,
        regressors, label,
    )
    results_list.append(result_df)


def run_bohn_robustness(df, dep_var, vars_list, label_prefix, results_list, min_obs=50):
    """Run the Bohn test on a (sub)sample and append the results.

    Args:
        df: panel with 'iso3', 'year', *dep_var* and every column in *vars_list*.
        dep_var: dependent variable (e.g. the primary balance).
        vars_list: regressor column names.
        label_prefix: specification label used in printouts and result rows.
        results_list: list that receives the result DataFrame (mutated).
        min_obs: minimum complete-case rows required to fit (default 50).
    """
    _run_panel_spec(df, dep_var, vars_list, label_prefix, results_list, min_obs)


def run_rg_robustness(df, dep_var, vars_list, label_prefix, results_list, min_obs=50):
    """Run the r-g regression on a (sub)sample and append the results.

    Arguments are as for run_bohn_robustness.
    """
    _run_panel_spec(df, dep_var, vars_list, label_prefix, results_list, min_obs)


def _print_section(title):
    """Print *title* between two 70-character '=' rules (the script's banner style)."""
    print(f"\n{'=' * 70}")
    print(title)
    print("=" * 70)


def _significance_stars(p):
    """Return '***', '**' or '*' for p < 0.01, 0.05 or 0.1 respectively, else ''.

    NaN p-values fail every comparison and therefore get ''.
    """
    if p < 0.01:
        return '***'
    if p < 0.05:
        return '**'
    if p < 0.1:
        return '*'
    return ''


def _coefficient_table(results_df, variables):
    """Return the rows of *results_df* whose 'variable' is in *variables*.

    Keeps the summary columns used by the comparison tables and adds a
    'significant' column with conventional significance stars.
    """
    table = results_df[results_df['variable'].isin(variables)][
        ['model', 'coefficient', 'std_error', 'p_value', 'n_obs', 'r_squared']
    ].copy()
    table['significant'] = table['p_value'].apply(_significance_stars)
    return table


def _lag_within_country(df, col):
    """Return *col* lagged by one year within each country, aligned to df's index.

    The shift is done on a copy sorted by (iso3, year), so the result does not
    depend on the CSV's row order. The lag is set to NaN whenever the previous
    observation for the country is not exactly one year earlier, so it never
    bridges a gap in the panel.
    """
    ordered = df.sort_values(['iso3', 'year'])
    grouped = ordered.groupby('iso3')
    lagged = grouped[col].shift(1)
    prev_year = grouped['year'].shift(1)
    lagged = lagged.where(ordered['year'] - prev_year == 1)
    return lagged.reindex(df.index)


def main():
    """Run the Phase 6 robustness suite and write the result tables.

    Re-estimates the Bohn test and the r-g regression on country subsamples
    and alternative specifications, then writes to TABLE_DIR:

    - phase6_robustness_results.csv: coefficients from every fitted spec
    - phase6_z1_comparison.csv: the Z_1 coefficient across all specs
    - phase6_bohn_comparison.csv: the Bohn debt coefficient across specs
      (debt_lag, or net_debt_lag for the net-debt spec)

    Returns:
        DataFrame of all results (empty if no specification could be fitted).
    """
    print("=" * 70)
    print("PHASE 6: Robustness Tests")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "fiscal_panel.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # Reconstruct the debt x demographics interactions used by the Bohn test.
    # df.get falls back to NaN, so a missing input gives an all-NaN column
    # here instead of a KeyError.
    for k in (1, 2, 3):
        df[f'debt_x_Z{k}'] = df.get('debt_lag', np.nan) * df.get(f'Z_{k}', np.nan)

    all_results = []
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    interaction_vars = ['debt_x_Z1', 'debt_x_Z2', 'debt_x_Z3']
    # Keep only the controls that are actually present in the panel.
    bohn_controls = [c for c in ['output_gap', 'govt_exp_gap'] if c in df.columns]
    rg_controls = [c for c in ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
                   if c in df.columns]

    bohn_vars = ['debt_lag'] + demo_vars + interaction_vars + bohn_controls
    rg_vars = demo_vars + rg_controls

    # =================================================================
    # Tests 1-6: baseline specifications on country subsamples
    # =================================================================
    subsamples = [
        ("  Test 1: Japan Exclusion", df['iso3'] != 'JPN',
         "Bohn: Excluding Japan", "r-g: Excluding Japan"),
        ("  Test 2: CCA Exclusion", ~df['iso3'].isin(CCA),
         "Bohn: Excluding CCA", "r-g: Excluding CCA"),
        ("  Test 3: OECD Subsample", df['iso3'].isin(OECD),
         "Bohn: OECD Only", "r-g: OECD Only"),
        ("  Test 4: Europe Subsample", df['iso3'].isin(EUROPE),
         "Bohn: Europe", "r-g: Europe"),
        ("  Test 5: Asia Subsample", df['iso3'].isin(ASIA),
         "Bohn: East/SE Asia", "r-g: East/SE Asia"),
        ("  Test 6: Emerging Markets", ~df['iso3'].isin(OECD),
         "Bohn: Emerging Markets", "r-g: Emerging Markets"),
    ]
    for title, mask, bohn_label, rg_label in subsamples:
        _print_section(title)
        sub = df[mask].copy()
        run_bohn_robustness(sub, 'primary_bal_gdp', bohn_vars, bohn_label, all_results)
        run_rg_robustness(sub, 'r_minus_g', rg_vars, rg_label, all_results)

    # =================================================================
    # 7. Net debt vs gross debt
    # =================================================================
    _print_section("  Test 7: Net Debt vs Gross Debt")
    if 'govt_net_debt_gdp' in df.columns and df['govt_net_debt_gdp'].notna().sum() >= 200:
        df['net_debt_lag'] = _lag_within_country(df, 'govt_net_debt_gdp')
        for k in (1, 2, 3):
            df[f'net_debt_x_Z{k}'] = df['net_debt_lag'] * df[f'Z_{k}']

        net_vars = (['net_debt_lag'] + demo_vars
                    + ['net_debt_x_Z1', 'net_debt_x_Z2', 'net_debt_x_Z3'] + bohn_controls)
        run_bohn_robustness(df, 'primary_bal_gdp', net_vars,
                            "Bohn: Net Debt", all_results)
    else:
        print("  Skipping: insufficient net debt data")

    # =================================================================
    # 8. Alternative rate measures
    # =================================================================
    _print_section("  Test 8: Alternative Rate Measures for r-g")
    if 'rgdp_growth' not in df.columns:
        print("  Skipping: rgdp_growth not available")
    else:
        for rate_var, rate_label in [('govt_bond_10y', 'Bond Yield'),
                                     ('policy_rate', 'Policy Rate'),
                                     ('lending_rate', 'Lending Rate'),
                                     ('real_bond_10y', 'Real Bond Yield')]:
            if rate_var not in df.columns:
                continue
            rg_col = f'rg_{rate_var}'
            # NOTE(review): only real_bond_10y is labelled as a real rate; the
            # others look nominal while rgdp_growth looks like real growth —
            # confirm the intended deflation for these r-g measures.
            df[rg_col] = df[rate_var] - df['rgdp_growth']
            needed = [rg_col] + [v for v in rg_vars if v in df.columns]
            if len(df.dropna(subset=needed)) >= 100:
                run_rg_robustness(df, rg_col, rg_vars,
                                  f"r-g ({rate_label})", all_results)

    # =================================================================
    # 9. Structural vs primary balance in Bohn test
    # =================================================================
    _print_section("  Test 9: Structural Balance Bohn Test")
    if 'structural_bal_gdp' in df.columns and df['structural_bal_gdp'].notna().sum() >= 200:
        run_bohn_robustness(df, 'structural_bal_gdp', bohn_vars,
                            "Bohn: Structural Balance", all_results)
    else:
        print("  Skipping: insufficient structural balance data")

    # =================================================================
    # 11. Pension spending control (numbering kept from earlier versions)
    # =================================================================
    _print_section("  Test 11: Pension Spending Control (OECD)")
    if 'pension_spending_gdp' in df.columns:
        df_pension = df[df['pension_spending_gdp'].notna()].copy()
        print(f"  Pension data: {len(df_pension)} obs, {df_pension['iso3'].nunique()} countries")

        if len(df_pension) >= 100:
            run_bohn_robustness(df_pension, 'primary_bal_gdp',
                                bohn_vars + ['pension_spending_gdp'],
                                "Bohn: With Pension Spending", all_results)
            run_rg_robustness(df_pension, 'r_minus_g',
                              rg_vars + ['pension_spending_gdp'],
                              "r-g: With Pension Spending", all_results)

    if not all_results:
        return pd.DataFrame()

    results_df = pd.concat(all_results, ignore_index=True)
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    # =================================================================
    # Coefficient comparisons — built only after every test (including the
    # pension specs) has run, so they really cover all robustness specs.
    # =================================================================
    _print_section("Z_1 COEFFICIENT COMPARISON ACROSS ALL ROBUSTNESS SPECS")
    z1_compare = _coefficient_table(results_df, ['Z_1'])
    print(z1_compare.to_string(index=False, float_format='%.4f'))
    z1_compare.to_csv(TABLE_DIR / "phase6_z1_comparison.csv", index=False)

    _print_section("BOHN COEFFICIENT (debt_lag) ACROSS ROBUSTNESS SPECS")
    # The net-debt spec names its Bohn coefficient net_debt_lag.
    debt_compare = _coefficient_table(results_df, ['debt_lag', 'net_debt_lag'])
    if len(debt_compare) > 0:
        print(debt_compare.to_string(index=False, float_format='%.4f'))
        debt_compare.to_csv(TABLE_DIR / "phase6_bohn_comparison.csv", index=False)

    results_path = TABLE_DIR / "phase6_robustness_results.csv"
    results_df.to_csv(results_path, index=False)
    print(f"\n{'=' * 70}")
    print(f"Saved: {results_path}")
    print(f"  {len(results_df)} rows across {results_df['model'].nunique()} models")

    return results_df


if __name__ == "__main__":
    results = main()
